#include "cartographer/mapping/internal/optimization/optimization_problem_2d.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "cartographer/common/internal/ceres_solver_options.h"
#include "cartographer/common/histogram.h"
#include "cartographer/common/math.h"
#include "cartographer/mapping/internal/optimization/ceres_pose.h"
#include "cartographer/mapping/internal/optimization/cost_functions/landmark_cost_function_2d.h"
#include "cartographer/mapping/internal/optimization/cost_functions/spa_cost_function_2d.h"
#include "cartographer/sensor/odometry_data.h"
#include "cartographer/transform/transform.h"
#include "ceres/ceres.h"
#include "glog/logging.h"

namespace cartographer
{
  namespace mapping
  {
    namespace optimization
    {
      namespace
      {

        using ::cartographer::mapping::optimization::CeresPose;
        using LandmarkNode = ::cartographer::mapping::PoseGraphInterface::LandmarkNode;
        using TrajectoryData = ::cartographer::mapping::PoseGraphInterface::TrajectoryData;

        // For fixed frame pose.
        std::unique_ptr<transform::Rigid3d> Interpolate(
            const sensor::MapByTime<sensor::FixedFramePoseData> &map_by_time,
            const int trajectory_id, const common::Time time)
        {
          const auto it = map_by_time.lower_bound(trajectory_id, time);
          if (it == map_by_time.EndOfTrajectory(trajectory_id) ||
              !it->pose.has_value())
          {
            return nullptr;
          }
          if (it == map_by_time.BeginOfTrajectory(trajectory_id))
          {
            if (it->time == time)
            {
              return absl::make_unique<transform::Rigid3d>(it->pose.value());
            }
            return nullptr;
          }
          const auto prev_it = std::prev(it);
          if (prev_it->pose.has_value())
          {
            return absl::make_unique<transform::Rigid3d>(
                Interpolate(transform::TimestampedTransform{prev_it->time, prev_it->pose.value()},
                            transform::TimestampedTransform{it->time, it->pose.value()}, time)
                    .transform);
          }
          return nullptr;
        }

        // Converts a pose into the 3 optimization variable format used for Ceres:
        // translation in x and y, followed by the rotation angle representing the
        // orientation.
        std::array<double, 3> FromPose(const transform::Rigid2d &pose)
        {
          return {{pose.translation().x(), pose.translation().y(),
                   pose.normalized_angle()}};
        }

        // Converts a pose as represented for Ceres back to an transform::Rigid2d pose.
        transform::Rigid2d ToPose(const std::array<double, 3> &values)
        {
          return transform::Rigid2d({values[0], values[1]}, values[2]);
        }

        // Selects a trajectory node closest in time to the landmark observation and
        // applies a relative transform from it.
        transform::Rigid3d GetInitialLandmarkPose(
            const LandmarkNode::LandmarkObservation &observation,
            const NodeSpec2D &prev_node, const NodeSpec2D &next_node,
            const std::array<double, 3> &prev_node_pose,
            const std::array<double, 3> &next_node_pose)
        {
          const double interpolation_parameter =
              common::ToSeconds(observation.time - prev_node.time) /
              common::ToSeconds(next_node.time - prev_node.time);

          const std::tuple<std::array<double, 4>, std::array<double, 3>>
              rotation_and_translation =
                  InterpolateNodes2D(prev_node_pose.data(), prev_node.gravity_alignment,
                                     next_node_pose.data(), next_node.gravity_alignment,
                                     interpolation_parameter);
          return transform::Rigid3d::FromArrays(std::get<0>(rotation_and_translation),
                                                std::get<1>(rotation_and_translation)) *
                 observation.landmark_to_tracking_transform;
        }

        void AddLandmarkCostFunctions(
            const std::map<std::string, LandmarkNode> &landmark_nodes,
            const MapById<NodeId, NodeSpec2D> &node_data,
            MapById<NodeId, std::array<double, 3>> *C_nodes,
            std::map<std::string, CeresPose> *C_landmarks, ceres::Problem *problem,
            double huber_scale)
        {
          for (const auto &landmark_node : landmark_nodes)
          {
            for (const auto &observation : landmark_node.second.landmark_observations)
            {
              const std::string &landmark_id = landmark_node.first;
              const auto &begin_of_trajectory =
                  node_data.BeginOfTrajectory(observation.trajectory_id);
              // The landmark observation was made before the trajectory was created.
              if (observation.time < begin_of_trajectory->data.time)
              {
                continue;
              }
              // Find the trajectory nodes before and after the landmark observation.
              auto next =
                  node_data.lower_bound(observation.trajectory_id, observation.time);
              // The landmark observation was made, but the next trajectory node has
              // not been added yet.
              if (next == node_data.EndOfTrajectory(observation.trajectory_id))
              {
                continue;
              }
              if (next == begin_of_trajectory)
              {
                next = std::next(next);
              }
              auto prev = std::prev(next);
              // Add parameter blocks for the landmark ID if they were not added before.
              std::array<double, 3> *prev_node_pose = &C_nodes->at(prev->id);
              std::array<double, 3> *next_node_pose = &C_nodes->at(next->id);
              if (!C_landmarks->count(landmark_id))
              {
                const transform::Rigid3d starting_point =
                    landmark_node.second.global_landmark_pose.has_value()
                        ? landmark_node.second.global_landmark_pose.value()
                        : GetInitialLandmarkPose(observation, prev->data, next->data,
                                                 *prev_node_pose, *next_node_pose);
                C_landmarks->emplace(
                    landmark_id,
                    CeresPose(starting_point, nullptr /* translation_parametrization */,
                              absl::make_unique<ceres::QuaternionParameterization>(),
                              problem));
                // Set landmark constant if it is frozen.
                if (landmark_node.second.frozen)
                {
                  problem->SetParameterBlockConstant(
                      C_landmarks->at(landmark_id).translation());
                  problem->SetParameterBlockConstant(
                      C_landmarks->at(landmark_id).rotation());
                }
              }
              problem->AddResidualBlock(
                  LandmarkCostFunction2D::CreateAutoDiffCostFunction(
                      observation, prev->data, next->data),
                  new ceres::HuberLoss(huber_scale), prev_node_pose->data(),
                  next_node_pose->data(), C_landmarks->at(landmark_id).rotation(),
                  C_landmarks->at(landmark_id).translation());
            }
          }
        }

      } // namespace

      OptimizationProblem2D::OptimizationProblem2D(
          const proto::OptimizationProblemOptions &options)
          : options_(options) {}

      OptimizationProblem2D::~OptimizationProblem2D() {}

      void OptimizationProblem2D::AddImuData(const int trajectory_id,
                                             const sensor::ImuData &imu_data)
      {
        // IMU data is not used in 2D optimization, so we ignore this part of the
        // interface.
      }

      void OptimizationProblem2D::AddOdometryData(
          const int trajectory_id, const sensor::OdometryData &odometry_data)
      {
        odometry_data_.Append(trajectory_id, odometry_data);
      }

      void OptimizationProblem2D::AddFixedFramePoseData(
          const int trajectory_id,
          const sensor::FixedFramePoseData &fixed_frame_pose_data)
      {
        fixed_frame_pose_data_.Append(trajectory_id, fixed_frame_pose_data);
      }

      void OptimizationProblem2D::AddTrajectoryNode(const int trajectory_id,
                                                    const NodeSpec2D &node_data)
      {
        node_data_.Append(trajectory_id, node_data);
        trajectory_data_[trajectory_id];
      }

      void OptimizationProblem2D::SetTrajectoryData(
          int trajectory_id, const TrajectoryData &trajectory_data)
      {
        trajectory_data_[trajectory_id] = trajectory_data;
      }

      void OptimizationProblem2D::InsertTrajectoryNode(const NodeId &node_id,
                                                       const NodeSpec2D &node_data)
      {
        node_data_.Insert(node_id, node_data);
        trajectory_data_[node_id.trajectory_id];
      }

      void OptimizationProblem2D::TrimTrajectoryNode(const NodeId &node_id)
      {
        empty_imu_data_.Trim(node_data_, node_id);
        odometry_data_.Trim(node_data_, node_id);
        fixed_frame_pose_data_.Trim(node_data_, node_id);
        node_data_.Trim(node_id);
        if (node_data_.SizeOfTrajectoryOrZero(node_id.trajectory_id) == 0)
        {
          trajectory_data_.erase(node_id.trajectory_id);
        }
      }

      void OptimizationProblem2D::AddSubmap(
          const int trajectory_id, const transform::Rigid2d &global_submap_pose)
      {
        submap_data_.Append(trajectory_id, SubmapSpec2D{global_submap_pose});
      }

      void OptimizationProblem2D::InsertSubmap(
          const SubmapId &submap_id, const transform::Rigid2d &global_submap_pose)
      {
        submap_data_.Insert(submap_id, SubmapSpec2D{global_submap_pose});
      }

      void OptimizationProblem2D::TrimSubmap(const SubmapId &submap_id)
      {
        submap_data_.Trim(submap_id);
      }

      void OptimizationProblem2D::SetMaxNumIterations(
          const int32 max_num_iterations)
      {
        options_.mutable_ceres_solver_options()->set_max_num_iterations(
            max_num_iterations);
      }

      void OptimizationProblem2D::Solve(
          const std::vector<Constraint> &constraints,
          const std::map<int, PoseGraphInterface::TrajectoryState> &
              trajectories_state,
          const std::map<std::string, LandmarkNode> &landmark_nodes)
      {
        if (node_data_.empty())
        {
          // Nothing to optimize.
          return;
        }

        std::set<int> frozen_trajectories;
        for (const auto &it : trajectories_state)
        {
          if (it.second == PoseGraphInterface::TrajectoryState::FROZEN)
          {
            frozen_trajectories.insert(it.first);
          }
        }

        ceres::Problem::Options problem_options;
        ceres::Problem problem(problem_options);

        // Set the starting point.
        // TODO(hrapp): Move ceres data into SubmapSpec.
        MapById<SubmapId, std::array<double, 3>> C_submaps;
        MapById<NodeId, std::array<double, 3>> C_nodes;
        std::map<std::string, CeresPose> C_landmarks;
        bool first_submap = true;
        for (const auto &submap_id_data : submap_data_)
        {
          const bool frozen =
              frozen_trajectories.count(submap_id_data.id.trajectory_id) != 0;
          C_submaps.Insert(submap_id_data.id,
                           FromPose(submap_id_data.data.global_pose));
          problem.AddParameterBlock(C_submaps.at(submap_id_data.id).data(), 3);
          if (first_submap || frozen)
          {
            first_submap = false;
            // Fix the pose of the first submap or all submaps of a frozen
            // trajectory.
            problem.SetParameterBlockConstant(C_submaps.at(submap_id_data.id).data());
          }
        }
        for (const auto &node_id_data : node_data_)
        {
          const bool frozen =
              frozen_trajectories.count(node_id_data.id.trajectory_id) != 0;
          C_nodes.Insert(node_id_data.id, FromPose(node_id_data.data.global_pose_2d));
          problem.AddParameterBlock(C_nodes.at(node_id_data.id).data(), 3);
          if (frozen)
          {
            problem.SetParameterBlockConstant(C_nodes.at(node_id_data.id).data());
          }
        }
        // Add cost functions for intra- and inter-submap constraints.
        for (const Constraint &constraint : constraints)
        {
          problem.AddResidualBlock(
              CreateAutoDiffSpaCostFunction(constraint.pose),
              // Loop closure constraints should have a loss function.
              constraint.tag == Constraint::INTER_SUBMAP
                  ? new ceres::HuberLoss(options_.huber_scale())
                  : nullptr,
              C_submaps.at(constraint.submap_id).data(),
              C_nodes.at(constraint.node_id).data());
        }
        // Add cost functions for landmarks.
        AddLandmarkCostFunctions(landmark_nodes, node_data_, &C_nodes, &C_landmarks,
                                 &problem, options_.huber_scale());
        // Add penalties for violating odometry or changes between consecutive nodes
        // if odometry is not available.
        for (auto node_it = node_data_.begin(); node_it != node_data_.end();)
        {
          const int trajectory_id = node_it->id.trajectory_id;
          const auto trajectory_end = node_data_.EndOfTrajectory(trajectory_id);
          if (frozen_trajectories.count(trajectory_id) != 0)
          {
            node_it = trajectory_end;
            continue;
          }

          auto prev_node_it = node_it;
          for (++node_it; node_it != trajectory_end; ++node_it)
          {
            const NodeId first_node_id = prev_node_it->id;
            const NodeSpec2D &first_node_data = prev_node_it->data;
            prev_node_it = node_it;
            const NodeId second_node_id = node_it->id;
            const NodeSpec2D &second_node_data = node_it->data;

            if (second_node_id.node_index != first_node_id.node_index + 1)
            {
              continue;
            }

            // Add a relative pose constraint based on the odometry (if available).
            std::unique_ptr<transform::Rigid3d> relative_odometry =
                CalculateOdometryBetweenNodes(trajectory_id, first_node_data,
                                              second_node_data);
            if (relative_odometry != nullptr)
            {
              problem.AddResidualBlock(
                  CreateAutoDiffSpaCostFunction(Constraint::Pose{
                      *relative_odometry, options_.odometry_translation_weight(),
                      options_.odometry_rotation_weight()}),
                  nullptr /* loss function */, C_nodes.at(first_node_id).data(),
                  C_nodes.at(second_node_id).data());
            }

            // Add a relative pose constraint based on consecutive local SLAM poses.
            const transform::Rigid3d relative_local_slam_pose =
                transform::Embed3D(first_node_data.local_pose_2d.inverse() *
                                   second_node_data.local_pose_2d);
            problem.AddResidualBlock(
                CreateAutoDiffSpaCostFunction(
                    Constraint::Pose{relative_local_slam_pose,
                                     options_.local_slam_pose_translation_weight(),
                                     options_.local_slam_pose_rotation_weight()}),
                nullptr /* loss function */, C_nodes.at(first_node_id).data(),
                C_nodes.at(second_node_id).data());
          }
        }

        std::map<int, std::array<double, 3>> C_fixed_frames;
        for (auto node_it = node_data_.begin(); node_it != node_data_.end();)
        {
          const int trajectory_id = node_it->id.trajectory_id;
          const auto trajectory_end = node_data_.EndOfTrajectory(trajectory_id);
          if (!fixed_frame_pose_data_.HasTrajectory(trajectory_id))
          {
            node_it = trajectory_end;
            continue;
          }

          const TrajectoryData &trajectory_data = trajectory_data_.at(trajectory_id);
          bool fixed_frame_pose_initialized = false;
          for (; node_it != trajectory_end; ++node_it)
          {
            const NodeId node_id = node_it->id;
            const NodeSpec2D &node_data = node_it->data;

            const std::unique_ptr<transform::Rigid3d> fixed_frame_pose =
                Interpolate(fixed_frame_pose_data_, trajectory_id, node_data.time);
            if (fixed_frame_pose == nullptr)
            {
              continue;
            }

            const Constraint::Pose constraint_pose{
                *fixed_frame_pose, options_.fixed_frame_pose_translation_weight(),
                options_.fixed_frame_pose_rotation_weight()};

            if (!fixed_frame_pose_initialized)
            {
              transform::Rigid2d fixed_frame_pose_in_map;
              if (trajectory_data.fixed_frame_origin_in_map.has_value())
              {
                fixed_frame_pose_in_map = transform::Project2D(
                    trajectory_data.fixed_frame_origin_in_map.value());
              }
              else
              {
                fixed_frame_pose_in_map =
                    node_data.global_pose_2d *
                    transform::Project2D(constraint_pose.zbar_ij).inverse();
              }

              C_fixed_frames.emplace(trajectory_id,
                                     FromPose(fixed_frame_pose_in_map));
              fixed_frame_pose_initialized = true;
            }

            problem.AddResidualBlock(
                CreateAutoDiffSpaCostFunction(constraint_pose),
                options_.fixed_frame_pose_use_tolerant_loss()
                    ? new ceres::TolerantLoss(
                          options_.fixed_frame_pose_tolerant_loss_param_a(),
                          options_.fixed_frame_pose_tolerant_loss_param_b())
                    : nullptr,
                C_fixed_frames.at(trajectory_id).data(), C_nodes.at(node_id).data());
          }
        }

        // Solve.
        ceres::Solver::Summary summary;
        ceres::Solve(
            common::CreateCeresSolverOptions(options_.ceres_solver_options()),
            &problem, &summary);
        if (options_.log_solver_summary())
        {
          LOG(INFO) << summary.FullReport();
        }

        // Store the result.
        for (const auto &C_submap_id_data : C_submaps)
        {
          submap_data_.at(C_submap_id_data.id).global_pose =
              ToPose(C_submap_id_data.data);
        }
        for (const auto &C_node_id_data : C_nodes)
        {
          node_data_.at(C_node_id_data.id).global_pose_2d =
              ToPose(C_node_id_data.data);
        }
        for (const auto &C_fixed_frame : C_fixed_frames)
        {
          trajectory_data_.at(C_fixed_frame.first).fixed_frame_origin_in_map =
              transform::Embed3D(ToPose(C_fixed_frame.second));
        }
        for (const auto &C_landmark : C_landmarks)
        {
          landmark_data_[C_landmark.first] = C_landmark.second.ToRigid();
        }
      }

      std::unique_ptr<transform::Rigid3d> OptimizationProblem2D::InterpolateOdometry(
          const int trajectory_id, const common::Time time) const
      {
        const auto it = odometry_data_.lower_bound(trajectory_id, time);
        if (it == odometry_data_.EndOfTrajectory(trajectory_id))
        {
          return nullptr;
        }
        if (it == odometry_data_.BeginOfTrajectory(trajectory_id))
        {
          if (it->time == time)
          {
            return absl::make_unique<transform::Rigid3d>(it->pose);
          }
          return nullptr;
        }
        const auto prev_it = std::prev(it);
        return absl::make_unique<transform::Rigid3d>(
            Interpolate(transform::TimestampedTransform{prev_it->time, prev_it->pose},
                        transform::TimestampedTransform{it->time, it->pose}, time)
                .transform);
      }

      std::unique_ptr<transform::Rigid3d>
      OptimizationProblem2D::CalculateOdometryBetweenNodes(
          const int trajectory_id, const NodeSpec2D &first_node_data,
          const NodeSpec2D &second_node_data) const
      {
        if (odometry_data_.HasTrajectory(trajectory_id))
        {
          const std::unique_ptr<transform::Rigid3d> first_node_odometry =
              InterpolateOdometry(trajectory_id, first_node_data.time);
          const std::unique_ptr<transform::Rigid3d> second_node_odometry =
              InterpolateOdometry(trajectory_id, second_node_data.time);
          if (first_node_odometry != nullptr && second_node_odometry != nullptr)
          {
            transform::Rigid3d relative_odometry =
                transform::Rigid3d::Rotation(first_node_data.gravity_alignment) *
                first_node_odometry->inverse() * (*second_node_odometry) *
                transform::Rigid3d::Rotation(
                    second_node_data.gravity_alignment.inverse());
            return absl::make_unique<transform::Rigid3d>(relative_odometry);
          }
        }
        return nullptr;
      }

    } // namespace optimization
  }   // namespace mapping
} // namespace cartographer
